Implement _dequantize for TorchAO quantizer#13538

Open
jiqing-feng wants to merge 6 commits into huggingface:main from jiqing-feng:torchao-dequantize

Conversation

@jiqing-feng (Contributor)

What does this PR do?

Implements the _dequantize() method for TorchAoHfQuantizer, enabling model.dequantize() to convert TorchAO-quantized models back to standard float weights.

Changes

  1. Add _dequantize() method: iterates over all nn.Linear modules, calls weight.dequantize() on TorchAOBaseTensor weights, replaces them with standard nn.Parameter objects, and resets any overridden extra_repr (see the sketch after this list).

  2. Fix _verify_if_layer_quantized: add an isinstance(module.weight, TorchAOBaseTensor) check so that dequantized layers (which are still nn.Linear but now hold plain tensor weights) are correctly detected as non-quantized.
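
A minimal sketch of that flow, assuming the quantizer stored its extra_repr override on the module instance (illustrative only, not the exact PR diff):

```python
import torch.nn as nn
from torchao.utils import TorchAOBaseTensor


def _dequantize(self, model):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and isinstance(module.weight, TorchAOBaseTensor):
            # dequantize() returns a plain high-precision tensor.
            module.weight = nn.Parameter(module.weight.dequantize(), requires_grad=False)
            # Drop the instance-level extra_repr override so the module
            # prints like a plain nn.Linear again (assumes the override
            # was set on the instance, not the class).
            module.__dict__.pop("extra_repr", None)
    return model
```

After this, model.dequantize() leaves the model with standard float weights that behave like an unquantized checkpoint.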

@github-actions Bot added the quantization, tests, and size/S (PR with diff < 50 LOC) labels on Apr 22, 2026
@jiqing-feng (Contributor, Author)

Hi @sayakpaul. Would you please review this PR? Thanks!

@sayakpaul (Member) left a comment


@vkuzo could you review too?

Comment on lines +826 to +828
assert isinstance(module.weight, TorchAOBaseTensor), (
    f"Layer {name} weight is {type(module.weight)}, expected TorchAOBaseTensor"
)
Member

Can we also enable dequantization tests for TorchAO tester mixin?

@jiqing-feng (Contributor, Author)

done.
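
The enabled test, roughly (a sketch; get_quantized_model is a hypothetical helper, not the actual tester-mixin API):

```python
import torch.nn as nn
from torchao.utils import TorchAOBaseTensor


class TorchAoTesterMixin:
    def test_torchao_dequantize(self):
        model = self.get_quantized_model()  # hypothetical helper
        model.dequantize()
        for module in model.modules():
            # No Linear layer should still hold a TorchAO tensor subclass.
            if isinstance(module, nn.Linear):
                assert not isinstance(module.weight, TorchAOBaseTensor)
```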

- Add _dequantize() method in TorchAoHfQuantizer that dequantizes
  TorchAOBaseTensor weights back to standard nn.Parameter
- Fix _verify_if_layer_quantized to check isinstance(weight, TorchAOBaseTensor)
  so dequantized layers are correctly detected as non-quantized
@HuggingFaceDocBuilderDev
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng force-pushed the torchao-dequantize branch from a819214 to 9102fb8 on May 13, 2026 03:05
@jiqing-feng requested a review from sayakpaul on May 13, 2026 03:20
@sayakpaul (Member) left a comment


Will also let @vkuzo review this.

from torchao.utils import TorchAOBaseTensor

for name, module in model.named_modules():
    if isinstance(module, nn.Linear) and isinstance(module.weight, TorchAOBaseTensor):
@vkuzo (Contributor)

TorchAOBaseTensor does not expose dequantize as a public API, it is defined on child classes. I agree that it would make sense to do so in the future. If you want to be safe here it might be better to check for individual tensor subclasses that do expose it.

@jiqing-feng (Contributor, Author)

Thanks for the review @vkuzo! You're right that dequantize() is defined on child classes rather than on TorchAOBaseTensor itself. I've added a hasattr guard so we safely skip any subclass that doesn't expose it. In practice all quantized tensor subclasses we encounter do implement dequantize(), but this makes it future-proof.
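
The guard, sketched as a standalone helper (dequantize_linear_weights is a hypothetical name; details beyond the hasattr check are assumptions):

```python
import torch.nn as nn
from torchao.utils import TorchAOBaseTensor


def dequantize_linear_weights(model: nn.Module) -> nn.Module:
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and isinstance(module.weight, TorchAOBaseTensor):
            # dequantize() is defined on concrete tensor subclasses rather
            # than on TorchAOBaseTensor itself, so skip subclasses that
            # don't expose it instead of raising AttributeError.
            if not hasattr(module.weight, "dequantize"):
                continue
            module.weight = nn.Parameter(module.weight.dequantize(), requires_grad=False)
    return model
```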

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng force-pushed the torchao-dequantize branch from 95d0118 to df36f1a on May 14, 2026 02:02
    ],
    ids=["int4wo", "int8wo", "int8dq"],
)
def test_torchao_dequantize(self, quant_type):
@sayakpaul (Member)

I ran the tests with the following command: ``

And there are test failures:

FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_dequantize[int4wo] - NotImplementedError: Int4Tensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload...
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_dequantize[int8wo] - RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_dequantize[int8dq] - NotImplementedError: LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/funct...

With the following diff I managed to get it down to two:

diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 840eaa338..e73c31561 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -367,6 +367,10 @@ class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
 class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
     """TorchAO quantization tests for Flux Transformer."""
 
+    @property
+    def torch_dtype(self):
+        return torch.bfloat16
+
 
 class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
     @property
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_dequantize[int4wo] - NotImplementedError: Int4Tensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload...
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_dequantize[int8dq] - NotImplementedError: LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/funct...

I am on an H100.

@jiqing-feng force-pushed the torchao-dequantize branch from 2cbe719 to 83431bf on May 14, 2026 02:47
@jiqing-feng force-pushed the torchao-dequantize branch from 83431bf to 450d0e4 on May 14, 2026 02:58
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng (Contributor, Author)

Hi @sayakpaul. I have fixed the dtype issue and skipped [int4wo] and [int8dq], since torchao does not implement the operators those cases need.
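
The skip might look roughly like this (a sketch, not the exact test code):

```python
import pytest


def test_torchao_dequantize(self, quant_type):
    # int4wo and int8dq raised NotImplementedError in dispatch above,
    # so skip them until torchao implements the required operators.
    if quant_type in ("int4wo", "int8dq"):
        pytest.skip(f"torchao does not implement ops needed to dequantize {quant_type}")
    ...
```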


Labels

quantization, tests, size/S (PR with diff < 50 LOC)


4 participants